package com.chatplus.application.aiprocessor.platform.chat;

import cn.bugstack.openai.executor.parameter.CompletionRequest;
import cn.bugstack.openai.executor.parameter.Functions;
import cn.bugstack.openai.executor.parameter.Message;
import cn.bugstack.openai.executor.parameter.RequestChannel;
import cn.bugstack.openai.session.OpenAiSession;
import cn.hutool.json.JSONObject;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.chatplus.application.aiprocessor.util.ChatWebSocketUtil;
import com.chatplus.application.common.exception.BadRequestException;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.domain.dto.ApiKeyDto;
import com.chatplus.application.domain.dto.FunctionParametersDto;
import com.chatplus.application.domain.entity.chat.ChatModelEntity;
import com.chatplus.application.domain.entity.functions.FunctionEntity;
import com.chatplus.application.domain.request.ws.ChatWebSocketRequest;
import com.chatplus.application.enumeration.AiPlatformEnum;
import com.chatplus.application.service.functions.FunctionService;
import com.chatplus.application.util.ConfigUtil;
import com.google.common.collect.Lists;
import lombok.Setter;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Chat 类型 AI 处理器服务接口
 */
public abstract class ChatProcessorService {
    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(ChatProcessorService.class);
    private FunctionService functionService;

    @Setter
    protected ChatWebSocketRequest chatRequest;

    protected List<ApiKeyDto> openApiKey = new ArrayList<>();

    @Autowired
    public void setFunctionService(FunctionService functionService) {
        this.functionService = functionService;
    }

    /**
     * 流式输出
     */
    public abstract void processStream() throws Exception;

    /**
     * 处理器,阻塞式输出，直接返回
     */
    public abstract String processSync();

    /**
     * AI 渠道处理器
     */
    public abstract AiPlatformEnum getPlatform();

    public RequestChannel castChannel() {
        return switch (getPlatform()) {
            case OPEN_AI -> RequestChannel.OPEN_AI;
            case CHAT_GLM -> RequestChannel.CHAT_GLM;
            case ALI_YUN_Q_WEN -> RequestChannel.ALI_YUN_Q_WEN;
            case BAI_DU -> RequestChannel.BAI_DU;
            case XUN_FEI -> RequestChannel.XUN_FEI;
            default -> throw new BadRequestException("未知的 AI 渠道");
        };
    }

    /**
     * 函数设置
     *
     * @return 返回函数列表
     */
    public List<Functions> getFunctionList() {

        if (chatRequest.getChatModel().getEnabledFunction() == Boolean.FALSE) {
            return Lists.newArrayList();
        }
        List<FunctionEntity> functionList = functionService.list(Wrappers.<FunctionEntity>lambdaQuery().eq(FunctionEntity::getEnabled, Boolean.TRUE));
        if (CollectionUtils.isEmpty(functionList)) {
            return Lists.newArrayList();
        }
        return functionList.stream().map(functionEntity -> {
            FunctionParametersDto functionParametersDto = functionEntity.getParameters();
            Map<String, FunctionParametersDto.FunctionFieldDetailBean> propertiesMap
                    = functionParametersDto.getProperties();
            JSONObject properties = new JSONObject();
            if (propertiesMap != null && !propertiesMap.isEmpty()) {
                propertiesMap.forEach((k, v) -> {
                    JSONObject propertiesField = new JSONObject();
                    propertiesField.putOpt("type", v.getType());
                    propertiesField.putOpt("description", v.getDescription());
                    properties.putOpt(k, propertiesField);
                });
            }
            Functions.Parameters parameters = Functions.Parameters.builder()
                    .type("object")
                    .properties(properties)
                    .required(functionParametersDto.getRequired()).build();
            return Functions.builder()
                    .name(functionEntity.getName())
                    .description(functionEntity.getDescription())
                    .parameters(parameters)
                    .build();
        }).toList();
    }

    public CompletionRequest getCompletionRequest() {
        // 组装请求参数
        List<Message> messages = new ArrayList<>();
        chatRequest.getChatContextList().forEach(msg -> {
            if (StringUtils.isNotEmpty(msg.getSystem())) {
                Message promptMessage = Message.builder().role(CompletionRequest.Role.SYSTEM).content(msg.getSystem()).build();
                messages.add(promptMessage);
            }
            if (StringUtils.isNotEmpty(msg.getPrompt())) {
                Message promptMessage = Message.builder().role(CompletionRequest.Role.USER).content(msg.getPrompt()).build();
                messages.add(promptMessage);
            }
            if (StringUtils.isNotEmpty(msg.getReply())) {
                Message replyMessage = Message.builder().role(CompletionRequest.Role.ASSISTANT).content(msg.getReply()).build();
                messages.add(replyMessage);
            }
        });
        ChatModelEntity chatModel = chatRequest.getChatModel();
        // 1. 创建参数
        return CompletionRequest.builder()
                .stream(true)
                .messages(messages)
                .functions(getFunctionList())
                .user(chatRequest.getUserId().toString())
                .tag(chatRequest.getSessionId())
                .model(chatModel.getValue())
                .channel(castChannel())
                .maxTokens(chatModel.getMaxTokens() != null ? chatModel.getMaxTokens() : 2048)
                .temperature(chatModel.getTemperature() != null ? chatModel.getTemperature() : 0.2f)
                .build();
    }

    /**
     * 初始化会话工厂
     */
    public abstract OpenAiSession getSessionFactory();

    /**
     * 初始化每个平台的 API KEY
     */
    public synchronized void instance() {
        openApiKey = ConfigUtil.getChatApiKey(getPlatform(), chatRequest.getChannel());
        // // 更新 API KEY 的最后使用时间
        if (CollectionUtils.isEmpty(openApiKey)) {
            ChatWebSocketUtil.replyErrorMessage(chatRequest.getSessionId(), "抱歉😔😔😔，系统已经没有可用的 API KEY，请联系管理员！");
        }
    }
}
